#!/usr/bin/env python

# priweavepng
# Weave selected channels from input PNG files into
# a multi-channel output PNG.

from __future__ import print_function

import argparse
import collections
import itertools
import re
import sys

from array import array

import png

"""
priweavepng file1.png [file2.png ...]

The `priweavepng` tool combines channels from the input images and
weaves a selection of those channels into an output image.

Conceptually an intermediate image is formed consisting of
all channels of all input images in the order given on the command line
and in the order of each channel in its image.
Then from 1 to 4 channels are selected and
an image is output with those channels.
The limit on the number of selected channels is
imposed by the PNG image format.

The `-c n` option selects channel `n`.
Further channels can be selected either by repeating the `-c` option,
or using a comma separated list.
For example `-c 3,2,1` will select channels 3, 2, and 1 in that order;
if the input is an RGB PNG, this will swop the Red and Blue channels.
The order is significant: the order in which the options are given is
the order of the output channels.
It is permissible, and sometimes useful
(for example, grey to colour expansion, see below),
to repeat the same channel.

If no `-c` option is used the default is
to select all of the input channels, up to the first 4.

`priweavepng` does not care about the meaning of the channels
and treats them as a matrix of values.

The number of output channels determines the colour mode of the PNG file:
L (1-channel, Grey), LA (2-channel, Grey+Alpha),
RGB (3-channel, Red+Green+Blue), RGBA (4-channel, Red+Green+Blue+Alpha).

The `priweavepng` tool can be used for a variety of
channel building, swopping, and extraction effects:

Combine 3 grayscale images into RGB colour:
    priweavepng grey1.png grey2.png grey3.png

Swop Red and Blue channels in colour image:
    priweavepng -c 3 -c 2 -c 1 rgb.png

Extract Green channel as a greyscale image:
    priweavepng -c 2 rgb.png

Convert a greyscale image to a colour image (all grey):
    priweavepng -c 1 -c 1 -c 1 grey.png

Add alpha mask from a separate (greyscale) image:
    priweavepng rgb.png grey.png

Extract alpha mask into a separate (greyscale) image:
    priweavepng -c 4 rgba.png

Steal alpha mask from second file and add to first.
Note that the intermediate image in this example has 7 channels:
    priweavepng -c 1 -c 2 -c 3 -c 7 rgb.png rgba.png

Take Green channel from 3 successive colour images to make a new RGB image:
    priweavepng -c 2 -c 5 -c 8 rgb1.png rgb2.png rgb3.png

"""

# One input image:
# - rows: the pixel rows of the image;
# - info: the info dictionary that PyPNG returned for the image.
Image = collections.namedtuple("Image", ["rows", "info"])

# One channel of the intermediate raster
# (all channels of all input images, in order):
# - image: index of the input image it comes from (0-based);
# - i: index of the channel within that image (0-based);
# - bitdepth: the bitdepth of this channel.
Channel = collections.namedtuple("Channel", ["image", "i", "bitdepth"])


class Error(Exception):
    """Raised when this tool is given input or options it cannot use."""


def weave(out, args):
    """Stack the input PNG files and extract channels
    into a single output PNG.

    `out` is the file object that the output PNG is written to.
    `args` is the parsed command line; the attributes used are:
    `input`, the list of input paths;
    `channel`, a list of channel numbers (starting from 1) that index
    the intermediate image, or a false value to select the default
    (all input channels, up to the first 4);
    `interlace`, passed on to the PNG writer.
    `args` itself is not modified.

    Raises `Error` if there is no input,
    if the number of selected channels is not from 1 to 4,
    or if a selected channel does not exist in the intermediate image.
    Raises `NotImplementedError` if the input images differ in size.
    """

    paths = args.input

    if not paths:
        raise Error("Required input is missing.")

    # List of Image instances
    images = []
    # Channel map. Maps from channel number (starting from 1)
    # to a Channel instance.
    channel_map = dict()
    channel = 1

    for image_index, path in enumerate(paths):
        inp = png.cli_open(path)
        rows, info = png.Reader(file=inp).asDirect()[2:]
        # Rows are held in memory so that all the images
        # can be stepped through together, row by row.
        rows = list(rows)
        image = Image(rows, info)
        images.append(image)
        # A later version of PyPNG may intelligently support
        # PNG files with heterogeneous bitdepths.
        # For now, assumes bitdepth of all channels in image
        # is the same.
        channel_bitdepth = (image.info["bitdepth"],) * image.info["planes"]
        for i in range(image.info["planes"]):
            channel_map[channel + i] = Channel(image_index, i, channel_bitdepth[i])
        channel += image.info["planes"]

    assert channel - 1 == sum(image.info["planes"] for image in images)

    # If no channels, select up to first 4 as default.
    # A local list is used so that the caller's `args` is left alone.
    if args.channel:
        selected = list(args.channel)
    else:
        selected = list(range(1, channel))[:4]

    out_channels = len(selected)
    if not (0 < out_channels <= 4):
        raise Error("Too many channels selected (must be 1 to 4)")

    # Reject channel numbers that are not in the intermediate image now,
    # rather than failing with a KeyError part way through writing.
    unknown = [c for c in selected if c not in channel_map]
    if unknown:
        raise Error(
            "No such channel: %s (channels available are 1 to %d)"
            % (", ".join(str(c) for c in unknown), channel - 1)
        )

    alpha = out_channels in (2, 4)
    greyscale = out_channels in (1, 2)

    # One bitdepth for each *output* channel,
    # taken from the input channel that supplies it.
    # (One bitdepth per input image is only right when each image
    # happens to supply exactly one output channel.)
    bitdepth = tuple(channel_map[c].bitdepth for c in selected)
    arraytype = "BH"[max(bitdepth) > 8]

    size = [image.info["size"] for image in images]
    # Currently, fail unless all images same size.
    if len(set(size)) > 1:
        raise NotImplementedError("Cannot cope when sizes differ - sorry!")
    size = size[0]

    # Values per row, of output image
    vpr = out_channels * size[0]

    # For each output channel, in order:
    # (index of source image, channel index in that image,
    # number of planes in that image).
    # Worked out once here, not once per row.
    plan = []
    for c in selected:
        source = channel_map[c]
        plan.append((source.image, source.i, images[source.image].info["planes"]))

    def weave_row_iter():
        """
        Yield each woven row in turn.
        """
        # The zip call creates an iterator that yields
        # a tuple with each element containing the next row
        # for each of the input images.
        for row_tuple in zip(*(image.rows for image in images)):
            # output row
            row = array(arraytype, [0] * vpr)
            # for each output channel select correct input channel
            for out_channel_i, (image_i, channel_i, n) in enumerate(plan):
                # incoming row (make it an array)
                irow = array(arraytype, row_tuple[image_i])
                # Every n-th value of the input row, starting at the
                # channel's offset, fills this output channel's slots.
                row[out_channel_i::out_channels] = irow[channel_i::n]
            yield row

    w = png.Writer(
        size[0],
        size[1],
        greyscale=greyscale,
        alpha=alpha,
        bitdepth=bitdepth,
        interlace=args.interlace,
    )
    w.write(out, weave_row_iter())


def comma_list(s):
    """
    Convert the string `s` to a list of integers.

    Each run of digits in `s` gives one integer, in order of appearance;
    any other characters only serve to separate the runs.
    A string with no digits gives an empty list.
    """

    numbers = []
    for match in re.finditer(r"\d+", s):
        numbers.append(int(match.group()))
    return numbers


def main(argv=None):
    """
    Command line entry point.

    `argv` is the argument list with the program name as its first
    element; when it is None, `sys.argv` is used.
    The woven PNG is written to binary stdout.
    Returns the result of `weave`.
    """

    arguments = (sys.argv if argv is None else argv)[1:]

    parser = argparse.ArgumentParser()
    # Each -c option is parsed into a list of integers by comma_list;
    # action="append" collects those lists into a list of lists.
    parser.add_argument(
        "-c",
        "--channel",
        action="append",
        type=comma_list,
        help="list of channels to extract",
    )
    parser.add_argument("--interlace", action="store_true", help="write interlaced PNG")
    parser.add_argument("input", nargs="+")
    args = parser.parse_args(arguments)

    if args.channel:
        # Flatten the list of lists into one list,
        # keeping the order given on the command line.
        flattened = []
        for group in args.channel:
            flattened.extend(group)
        args.channel = flattened

    return weave(png.binary_stdout(), args)


# Run the tool when this file is executed as a script
# (and not when it is imported as a module).
if __name__ == "__main__":
    main()
